# Minimal ADE20K segmentation training with MK-CAViT.
import argparse
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

from MK_CAViT import mk_cavit_base
from train_utils import set_seed, evaluate_seg


class ADE20KSeg(Dataset):
    """ADE20K scene-parsing dataset yielding ``(image, label_mask)`` pairs.

    Expected directory structure:
        root/
          images/training/*.jpg
          images/validation/*.jpg
          annotations/training/*.png
          annotations/validation/*.png

    Each image is resized (shorter side -> ``size``), center-cropped to
    ``size x size``, converted to a float tensor and normalized with the
    ImageNet mean/std.  The annotation goes through the same resize/crop
    with nearest-neighbour interpolation so class ids are never blended.

    Args:
        root: Dataset root containing ``images/`` and ``annotations/``.
        split: Sub-directory name under both of those folders.
        size: Side length of the square output crop, in pixels.
        reduce_zero_label: When True (default), annotation value 0 is mapped
            to ``IGNORE_INDEX`` and every other value is shifted down by one,
            so raw ids 1..150 become 0..149.
            NOTE(review): this assumes the raw ADEChallengeData2016 masks,
            which are understood to store 0 for "unlabeled" and 1..150 for
            the classes -- confirm against your copy.  Pass False if the
            masks on disk are already zero-based with 255 as the void id.
    """

    # Label value that the loss is expected to skip (void / unlabeled pixels).
    IGNORE_INDEX = 255
    # Accepted image extensions; matched case-insensitively.
    _IMG_EXTS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root: str, split: str = 'training', size: int = 512,
                 reduce_zero_label: bool = True):
        self.img_dir = os.path.join(root, 'images', split)
        self.ann_dir = os.path.join(root, 'annotations', split)
        # Sorted so that index -> file mapping is deterministic across runs.
        self.files = sorted(
            name for name in os.listdir(self.img_dir)
            if name.lower().endswith(self._IMG_EXTS)
        )
        self.size = size
        self.reduce_zero_label = reduce_zero_label

        self.img_tf = T.Compose([
            T.Resize(size, interpolation=T.InterpolationMode.BILINEAR),
            T.CenterCrop(size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # NEAREST keeps the mask values discrete (no interpolated class ids).
        self.lbl_tf = T.Compose([
            T.Resize(size, interpolation=T.InterpolationMode.NEAREST),
            T.CenterCrop(size),
        ])

    @staticmethod
    def _reduce_zero_label(mask: np.ndarray) -> np.ndarray:
        """Return an int64 copy of *mask* with 0 -> 255 and k -> k - 1.

        Values that already equal 255 stay 255.  The input is not modified.
        """
        # Widen first: subtracting from a uint8 array would wrap 0 around.
        shifted = mask.astype(np.int64, copy=True)
        void = (shifted == 0) | (shifted == ADE20KSeg.IGNORE_INDEX)
        shifted -= 1
        shifted[void] = ADE20KSeg.IGNORE_INDEX
        return shifted

    def __len__(self):
        """Return the number of image files found for this split."""
        return len(self.files)

    def __getitem__(self, idx: int):
        """Load sample *idx*.

        Returns:
            A tuple ``(img, label)`` where ``img`` is the normalized image
            tensor produced by ``img_tf`` and ``label`` is an int64 tensor
            built from the transformed annotation PNG (remapped when
            ``reduce_zero_label`` is set).
        """
        fn = self.files[idx]
        # The annotation shares the image's base name but is always a PNG.
        ann_path = os.path.join(self.ann_dir, os.path.splitext(fn)[0] + '.png')

        # Context managers release the underlying file handles promptly.
        with Image.open(os.path.join(self.img_dir, fn)) as raw_img:
            img = self.img_tf(raw_img.convert('RGB'))
        with Image.open(ann_path) as raw_ann:
            mask = np.array(self.lbl_tf(raw_ann), dtype=np.int64)

        if self.reduce_zero_label:
            mask = self._reduce_zero_label(mask)
        return img, torch.from_numpy(mask)


def main(args):
    """Train MK-CAViT on ADE20K and save the final weights to ``args.out``.

    Args:
        args: Parsed command-line namespace.  Reads ``seed``, ``size``,
            ``root``, ``batch_size``, ``workers``, ``lr``, ``amp``,
            ``epochs`` and ``out``.

    After every epoch the model is evaluated with ``evaluate_seg`` on the
    validation split and a one-line summary is printed.  Only the final
    ``state_dict`` is written, once all epochs have finished.
    """
    num_classes = 150
    ignore_index = 255  # label value excluded from the training loss

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # CUDA autocast / GradScaler are no-ops (with warnings) without a GPU,
    # so only enable mixed precision when we actually run on CUDA.
    use_amp = bool(args.amp) and device.type == 'cuda'

    # Fail fast: make sure the checkpoint destination exists *before*
    # spending hours training, instead of erroring in torch.save at the end.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    model = mk_cavit_base(num_classes=num_classes, img_size=args.size).to(device)

    train_set = ADE20KSeg(args.root, 'training', args.size)
    val_set = ADE20KSeg(args.root, 'validation', args.size)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers, pin_memory=True)

    optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.05)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(args.epochs):
        # Re-enter train mode each epoch in case evaluation switched it off.
        model.train()
        tot = 0.0
        for imgs, labels in train_loader:
            # Batches come from pinned memory, so the copies can be async.
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optim.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                # Ask the model for logits at the label resolution.
                logits = model.forward_seg(imgs, out_size=labels.shape[-2:])
                loss = F.cross_entropy(logits, labels, ignore_index=ignore_index)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            # Weight by batch size so the epoch average is per-sample.
            tot += loss.item() * imgs.size(0)

        ev = evaluate_seg(model, val_loader, device, num_classes=num_classes)
        print(f"[{epoch+1:03d}/{args.epochs:03d}] train loss {tot/len(train_set):.4f} | "
              f"val loss {ev['loss']:.4f} | pixAcc {ev['pixAcc']:.2f}%")

    torch.save(model.state_dict(), args.out)


if __name__ == "__main__":
    # Command-line entry point: build the option parser, then hand the
    # parsed namespace to main().
    parser = argparse.ArgumentParser(description="Train MK-CAViT on ADE20K")

    # Dataset root directory (mandatory).
    parser.add_argument('--root', type=str, required=True)

    # Integer-valued options, declared in the order they appear in --help.
    for flag, default in (
        ('--size', 512),
        ('--epochs', 80),
        ('--batch_size', 4),
        ('--workers', 8),
    ):
        parser.add_argument(flag, type=int, default=default)

    parser.add_argument('--lr', type=float, default=3e-4)
    # Opt-in switch for mixed-precision training.
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    # Destination file for the trained weights.
    parser.add_argument('--out', type=str, default='mk_cavit_ade20k.pth')

    cli_args = parser.parse_args()
    main(cli_args)
